#!/usr/bin/env python3
"""A script to analyze the dependency relations in a CloudFormation template.

Use like so:

  cdk -a integ.something.js synth | ./template-deps-to-dot | dot -Tpng > deps.png
"""
import collections
import fileinput
import sys
try:
  import yaml
except ImportError:
  print("Please run 'pip3 install pyyaml'")
  sys.exit(1)


def main():
  """Read CloudFormation templates and write their dependency graph as dot.

  Templates are read from the files named on the command line, or from stdin
  when no arguments are given. Every template is added to one shared graph,
  which is then rendered to stdout.
  """
  filenames = sys.argv[1:]

  templates = []
  if filenames:
    for filename in filenames:
      # Close each file as soon as it has been parsed rather than leaking
      # the handle until interpreter exit.
      with open(filename) as f:
        templates.append(yaml.safe_load(f))
  else:
    templates.append(yaml.safe_load(sys.stdin))

  graph = DepGraph()

  for template in templates:
    # safe_load returns None for empty input and a scalar or list for YAML
    # that is not a mapping; none of those can be parsed as a template.
    if not template or not isinstance(template, dict):
      sys.stderr.write('Input does not look like a CloudFormation template.\n')
      continue
    parse_template(template, graph)

  graph.render_dot(sys.stdout)


def parse_template(template, graph):
  """Parse template and add all encountered dependencies to the graph.

  Args:
    template: parsed template mapping; only its 'Resources' section is read.
    graph: object providing annotate(node, text) and add(source, dep).

  Every resource is annotated with its display name and type. An edge is
  added for each reference found in its Properties (targets starting with
  'AWS::' are skipped) and for each DependsOn entry.
  """
  # 'or {}' also covers sections that are present but null in the YAML.
  resources = template.get('Resources') or {}

  for logical_id, resource_spec in resources.items():
    resource_type = resource_spec.get('Type', 'AWS::???::???')
    path = (resource_spec.get('Metadata') or {}).get('aws:cdk:path', None)
    if path:
      path = resource_name_from_path(path)

    # Fall back to the logical id when no (non-empty) path is available.
    source = '%s\n(%s)' % (path or logical_id, resource_type)

    graph.annotate(logical_id, source)

    for dep in find_property_references(resource_spec.get('Properties') or {}):
      if not dep.target.startswith('AWS::'):
        graph.add(logical_id, dep)

    depends_on = resource_spec.get('DependsOn') or []
    if isinstance(depends_on, str):
      # DependsOn may be a single logical id instead of a list; iterating the
      # bare string would add one bogus edge per character.
      depends_on = [depends_on]
    for target in depends_on:
      graph.add(logical_id, Dep(target, 'DependsOn'))


def resource_name_from_path(path):
  """Shorten a '/'-separated construct path into a display name.

  Components equal to 'Resource' are removed, then the first remaining
  component is dropped (presumably the stack name -- verify against the
  paths the caller passes in). The rest is re-joined with '/'.
  """
  kept = []
  for component in path.split('/'):
    if component == 'Resource':
      continue
    kept.append(component)
  del kept[:1]
  return '/'.join(kept)


def find_property_references(properties):
  """Collect every intrinsic reference nested inside a Properties mapping.

  Lists and dicts are walked recursively. A dict that parse_reference
  recognises becomes one Dep and is not descended into further. Each Dep is
  labelled with the name of the top-level property it was found under.

  Returns:
    list of Dep objects, in traversal order
  """
  found = []

  def visit(label, node):
    if isinstance(node, dict):
      reference = parse_reference(node)
      if reference:
        found.append(Dep(reference[0], label))
      else:
        for child in node.values():
          visit(label, child)
    elif isinstance(node, list):
      for child in node:
        visit(label, child)

  for top_level_name, value in properties.items():
    visit(top_level_name, value)
  return found


class DepGraph:
  """Directed dependency graph that can be written out in dot format.

  Attributes:
    graph: maps a source node to the set of dependency edges leaving it.
    annotations: maps a node to the text used for its label.
    has_incoming: set of nodes that are the target of at least one edge.
  """

  def __init__(self):
    self.annotations = {}
    self.has_incoming = set()
    # source node -> set of outgoing edges (objects with .target and .label)
    self.graph = collections.defaultdict(set)

  def annotate(self, node, annotation):
    """Set the label text for a node, replacing any earlier annotation."""
    self.annotations[node] = annotation

  def add(self, source, dep):
    """Record an edge from source to dep.target."""
    outgoing = self.graph[source]
    outgoing.add(dep)
    self.has_incoming.add(dep.target)

  def render_dot(self, f):
    """Render a dot version of this graph to the given stream."""
    for header_line in ('digraph G {\n', '  rankdir=LR;\n', '  node [shape=box];\n'):
      f.write(header_line)

    # Annotated nodes that take part in no edge are left out of the output.
    for node in self.annotations:
      connected = node in self.graph or node in self.has_incoming
      if not connected:
        continue
      label = fancy_label(self.annotations[node])
      f.write('  {} [label={}];\n'.format(dot_escape(node), label))

    for source in self.graph:
      for dep in self.graph[source]:
        f.write('  {} -> {} [label={}];\n'.format(
            dot_escape(source), dot_escape(dep.target), dot_escape(dep.label)))

    f.write('}\n')


def fancy_label(s):
  """Format a multi-line string as a Graphviz HTML-like label.

  The first line is rendered at point size 14; every following line is put
  on its own row (after a <BR/>) at point size 10.

  Args:
    s: label text, lines separated by '\\n'.

  Returns:
    The label including its surrounding '<' ... '>' delimiters, ready to be
    used as an attribute value in dot output.
  """
  import html  # local import: only this function needs it

  # The text lands inside HTML-like markup, so '&', '<' and '>' (and quotes)
  # must be turned into entities or the label no longer parses.
  lines = [html.escape(line) for line in s.split('\n')]
  return ('<<FONT POINT-SIZE="14">' + lines[0] + '</FONT>'
    + ''.join('<BR/><FONT POINT-SIZE="10">' + line + '</FONT>' for line in lines[1:])
    + ' >')


def dot_escape(s):
  """Return s as a double-quoted dot string.

  Backslashes and double quotes are backslash-escaped so they cannot end the
  quoted string early or swallow the closing quote. Newlines are written as
  the two-character sequence '\\n'.

  Args:
    s: the raw text (node id or label).

  Returns:
    The escaped text wrapped in double quotes.
  """
  # Backslashes must be handled first so the escapes added afterwards are not
  # themselves doubled.
  escaped = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
  return '"' + escaped + '"'


def parse_reference(obj):
  """If this object is an intrinsic reference, return info about it.

  Recognised shapes are dicts with exactly one key:
    {'Ref': logicalId}
    {'Fn::GetAtt': [logicalId, attribute]}
    {'Fn::GetAtt': 'logicalId.attribute'}

  Args:
    obj: a dict taken from a template.

  Returns: (logicalId, reference) if reference, None otherwise. For a Ref
  the second element is the literal string 'Ref'; for Fn::GetAtt it is the
  attribute. A Fn::GetAtt whose argument has neither shape yields None.

  """
  keys = list(obj)
  if keys == ['Ref']:
    return (obj['Ref'], 'Ref')
  if keys == ['Fn::GetAtt']:
    arg = obj['Fn::GetAtt']
    if isinstance(arg, str):
      # Dotted form: split on the first '.', since the attribute part may
      # itself contain dots. Indexing the string would return characters.
      logical_id, dot, attribute = arg.partition('.')
      if dot:
        return (logical_id, attribute)
      return None
    if isinstance(arg, (list, tuple)) and len(arg) >= 2:
      return (arg[0], arg[1])
  return None



class Dep:
  """A dependency edge: the node depended upon plus a label for the edge.

  Instances compare equal when both target and label are equal, and hash
  accordingly, so duplicate edges collapse when stored in a set.
  """

  def __init__(self, target, label):
    self.target, self.label = target, label

  def __hash__(self):
    key = (self.target, self.label)
    return hash(key)

  def __eq__(self, rhs):
    if not isinstance(rhs, Dep):
      return False
    same_target = self.target == rhs.target
    return same_target and self.label == rhs.label

  def __ne__(self, rhs):
    equal = self == rhs
    return not equal



# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
  main()
